# -*- coding: utf-8 -*-

# Preparation

# Core imports used by every section below: pandas for all DataFrames and
# re for the tweet clean-up regexes.
import pandas as pd
import re

from google.colab import drive

# Mount Google Drive at /content/gdrive (only available inside Colab).
drive.mount('/content/gdrive')

# Model training

# Install Hugging Face transformers (IPython shell magic; only valid in a notebook).
# NOTE(review): the version is unpinned, so keyword arguments used further down
# (e.g. the TrainingArguments evaluation-strategy keyword) may differ between releases.
!pip3 install transformers

"""## Load manifestos

Load corpus data
"""

import requests
import json

root_url = "https://manifesto-project.wzb.eu/tools/"
api_key = ""   #provide own manifesto project api key
dataset_version = "MPDS2022a"

# Fail early with a clear message instead of an opaque parsing error further down.
if not api_key:
    raise ValueError("Set `api_key` to your own Manifesto Project API key before running this cell.")

# Download the core dataset. Passing the query through `params` lets requests
# URL-encode the key, and raise_for_status() surfaces HTTP errors (bad key, wrong
# version) directly instead of as a JSON/DataFrame failure.
core_dataset_link = f"{root_url}api_get_core.json"
response = requests.get(core_dataset_link, params={"key": dataset_version, "api_key": api_key})
response.raise_for_status()
core_dataset = json.loads(response.content.decode('utf-8'))

# The payload is a list of rows whose first row holds the column names.
core_dataset = pd.DataFrame(core_dataset)
core_dataset.columns = core_dataset.iloc[0]
core_dataset.drop(index=0, inplace=True)

core_dataset.head()

# Manifesto keys used by the text API have the form "<party>_<date>".
core_dataset["manifesto_id"] = core_dataset["party"].astype(str) + "_" + core_dataset["date"].astype(str)

# Manifestos excluded explicitly (the reason is not stated in this script).
EXCLUDED_MANIFESTO_IDS = ["43020_201910", "43901_201910"]

# Keep only rows whose `manual` field is "5" (presumably the coding-handbook
# version, matching the handbook-5 sub-codes in the recoding table) and that are
# not explicitly excluded.
keep_rows = (~core_dataset["manifesto_id"].isin(EXCLUDED_MANIFESTO_IDS)) & (core_dataset["manual"] == "5")
core_dataset = core_dataset[keep_rows]

core_dataset.head()

relevant_manifesto_ids = core_dataset["manifesto_id"].unique()

list(relevant_manifesto_ids)

# Download the annotated (sentence-level, CMP-coded) texts of all relevant manifestos.
text_and_annotations_link = f"{root_url}api_texts_and_annotations.json"
response = requests.post(
    text_and_annotations_link,
    data={"keys[]": list(relevant_manifesto_ids), "version": "2022-1", "api_key": api_key},
)
response.raise_for_status()
manifestos = json.loads(response.content.decode('utf-8'))

collect_manifestos = []
for manifesto in manifestos["items"]:
    # One row per annotated text unit of this manifesto.
    annotated_text_df = pd.DataFrame.from_dict(dict(enumerate(manifesto["items"])), orient="index")

    # Skip manifestos without any usable annotation: no `cmp_code` column at all
    # (which previously raised KeyError), or every code equal to "NA".
    if "cmp_code" not in annotated_text_df.columns or (annotated_text_df["cmp_code"] == "NA").all():
        continue

    # Some payloads call the text column "content"; rename() ignores it if absent.
    annotated_text_df = annotated_text_df.rename(columns={"content": "text"})

    key_parts = manifesto["key"].split("_")
    annotated_text_df["manifesto_id"] = manifesto["key"]
    annotated_text_df["party"] = key_parts[0]
    annotated_text_df["date"] = key_parts[1]
    collect_manifestos.append(annotated_text_df)

if not collect_manifestos:
    raise ValueError("The Manifesto API returned no annotated manifestos for the requested keys.")

corpus_df = pd.concat(collect_manifestos, ignore_index=True, sort=False)

corpus_df

# country_code = first two characters of the party ID (cf. the 41/42/43 keys of `country_codes`).
corpus_df["country_code"] = corpus_df["party"].astype(str).str[:2]

"""## Prepare data"""

# Two-digit party-ID prefix -> country name (not used by the code below).
country_codes = {
    41: "Germany",
    42: "Austria",
    43: "Switzerland",
}

# Fine-grained CMP codes -> 20 aggregated categories ("1".."20"; see text_label_dict).
recoding_dict = {"101": "1", "102": "1", "103.1":"1", "103.2":"1","106":"1", "107":"1","109":"1",
                 "108":"2","110":"2",
                 "104":"3","105":"3",
                 "201.1":"4","201.2":"4",
                 "202.1":"5","202.2":"5","202.3":"5","202.4":"5","203":"5","204":"5",
                 "301":"6","302":"6","303":"6","304":"6",
                 "305.1":"7","305.2":"7","305.3":"7","305.4":"7","305.5":"7","305.6":"7",
                 "401":"8","402":"8","403":"8","404":"8","405":"8","406":"8","407":"8","408":"8","409":"8","410":"8","412":"8","413":"8","414":"8","415":"8","416.1":"8","704":"8",
                 "411":"9",
                 "416.2":"10","501":"10",
                 "502":"11",
                 "503":"12","705":"12",
                 "504":"13","505":"13","706":"13",
                 "506":"14","507":"14",
                 "601.1":"15","602.1":"15","603":"15","604":"15","606.1":"15","606.2":"15",
                 "601.2":"16","602.2":"16","607.1":"16","607.2":"16","607.3":"16","608.1":"16","608.2":"16","608.3":"16",
                 "605.1":"17","605.2":"17",
                 "701":"18","702":"18",
                 "703.1":"19","703.2":"19",
                 "000":"20"}

corpus_df = corpus_df.dropna(subset=["cmp_code"])

# Map every code onto the aggregated scheme. Codes outside recoding_dict ("H",
# "", "NA", CEE codes such as "6014", ...) become NaN and are dropped below.
# Using map() + assign() instead of replace(..., inplace=True) on a column avoids
# chained assignment, which does not write back under pandas Copy-on-Write and would
# leave these codes in place (crashing the label lookup during tokenization).
corpus_df = corpus_df.assign(cmp_code_recoded=corpus_df["cmp_code"].map(recoding_dict))

unmapped_codes = corpus_df.loc[corpus_df["cmp_code_recoded"].isna(), "cmp_code"].value_counts()
if not unmapped_codes.empty:
    print("Dropping text units whose cmp_code is outside the recoding scheme:")
    print(unmapped_codes.to_string())

corpus_df = corpus_df.dropna(subset=["cmp_code_recoded"])

corpus_df

corpus_df.shape

"""## Tokenize Train and Test data"""

from transformers import BertTokenizer, BertModel,BertForSequenceClassification

tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')

# Encode all training texts in one call: truncate to 512 tokens, pad to the longest.
training_texts = corpus_df["text"].tolist()
training_sequences = tokenizer(training_texts, truncation=True, padding=True, max_length=512)

# Category <-> class-index mappings: class index i stands for recoded category str(i + 1).
cmp_codes = list(map(str, range(1, 21)))
index_to_cmp = dict(enumerate(cmp_codes))
cmp_to_index = {code: position for position, code in index_to_cmp.items()}

# A KeyError here means a recoded category outside "1".."20" reached this point.
training_labels = [cmp_to_index[code] for code in corpus_df["cmp_code_recoded"]]

import torch

#create custom torch dataset type
class MyDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing tokenizer output with integer class labels.

    ``sequences`` maps feature names (e.g. ``input_ids``, ``attention_mask``) to
    per-example lists; ``labels`` holds one class index per example. Each item is
    a dict of tensors for one example plus its ``labels`` tensor.
    """

    def __init__(self, sequences, labels):
        self.sequences = sequences
        self.labels = labels

    def __len__(self):
        # One example per label.
        return len(self.labels)

    def __getitem__(self, idx):
        example = {}
        for feature, values in self.sequences.items():
            example[feature] = torch.tensor(values[idx])
        example['labels'] = torch.tensor(self.labels[idx])
        return example

train_dataset = MyDataset(training_sequences, training_labels)

from torch.utils.data import random_split

# Hold out ~10% of the examples for validation during training. Only the
# validation size is rounded; the training size is the remainder, so the two
# lengths always sum to len(train_dataset). Rounding both fractions independently
# (e.g. n=15 -> 14 + 2) makes random_split raise ValueError.
SPLIT_SEED = 42  # fixed seed so the train/validation split is reproducible
n_examples = len(train_dataset)
n_valid = round(n_examples * 0.1)
train_dataset, valid_dataset = random_split(
    train_dataset,
    (n_examples - n_valid, n_valid),
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

"""## Train Model"""

from transformers import BertForSequenceClassification


# One output unit per aggregated category.
model = BertForSequenceClassification.from_pretrained("bert-base-multilingual-cased", num_labels = len(cmp_codes))

import inspect

from transformers import Trainer, TrainingArguments

# transformers renamed `evaluation_strategy` to `eval_strategy` (v4.41), and later
# releases reject the old keyword. The install cell is unpinned, so pass the
# evaluation schedule under whichever name this TrainingArguments accepts.
_eval_strategy_key = (
    "eval_strategy"
    if "eval_strategy" in inspect.signature(TrainingArguments).parameters
    else "evaluation_strategy"
)

training_args = TrainingArguments(
    num_train_epochs=1,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=4,   # effective train batch: 4 x 4 = 16 per device
    learning_rate=3e-5,
    warmup_steps=100,
    weight_decay=0.01,
    output_dir='./results',
    logging_dir='./logs',
    logging_steps=2000,
    save_strategy="no",              # no checkpoints; the model is saved explicitly after training
    **{_eval_strategy_key: "epoch"},  # evaluate on the validation split after every epoch
)

from sklearn.metrics import accuracy_score

def compute_metrics(pred):
    """Return the accuracy of an evaluation run for the Hugging Face ``Trainer``.

    ``pred`` is the ``EvalPrediction`` the Trainer passes in: ``label_ids`` holds
    the gold class indices and ``predictions`` the per-class logits. Models that
    return several outputs produce a tuple here; the logits are its first element.

    Returns:
        ``{"accuracy": float}`` -- fraction of examples whose arg-max class
        equals the gold label.
    """
    logits = pred.predictions
    if isinstance(logits, tuple):
        logits = logits[0]
    preds = logits.argmax(-1)
    acc = float((preds == pred.label_ids).mean())
    return {
        'accuracy': acc,
    }

trainer_kwargs = dict(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,      # held-out split, scored by compute_metrics
    compute_metrics=compute_metrics,
)
trainer = Trainer(**trainer_kwargs)

# Fine-tune; evaluation follows the schedule configured in training_args.
trainer.train()

"""save model"""

cached_model_directory_name = "bert-multiling_no_context_512tokens_V2" # provide correct path

# Persist the fine-tuned model (weights and config) to the directory above.
trainer.save_model(cached_model_directory_name)

## load finetuned model
from transformers import BertForSequenceClassification, Trainer

model = BertForSequenceClassification.from_pretrained(cached_model_directory_name)

# A Trainer with default TrainingArguments; below it is only used for predict().
trainer = Trainer(model=model)

"""# Cross-domain classification

## Press releases

load data
"""

data_path = "Corpus_Pressreleases.csv" # provide correct path

pressrelease_df = pd.read_csv(data_path,  encoding="cp1252", index_col=0)

pressrelease_df.head()

pressrelease_df["date"] = pd.to_datetime(pressrelease_df["date"])

# Missing texts become "" before the keyword test: `"termin" in text` raises
# TypeError for NaN values.
pressrelease_df["text"] = pressrelease_df["text"].fillna("")

# Releases containing any of these keywords are labelled "not policy-related"
# and are not passed to the classifier.
pressrelease_non_policy_keywords = ("termin", "Termin", "Aviso", "AVISO")
is_non_policy_pressrelease = pressrelease_df["text"].map(
    lambda text: any(keyword in text for keyword in pressrelease_non_policy_keywords)
)
pressrelease_df["type"] = is_non_policy_pressrelease.map({True: "not policy-related", False: "policy-related"})

# Split into independent copies so adding columns below neither triggers
# SettingWithCopyWarning nor depends on pressrelease_df.
pressrelease_non_policy_df = pressrelease_df[pressrelease_df["type"] == "not policy-related"].copy()
pressrelease_policy_df = pressrelease_df[pressrelease_df["type"] == "policy-related"].copy()

pressrelease_non_policy_df["issue"] = "NA"

# Recoded category -> human-readable issue label.
text_label_dict = {"1": "Foreign_Affairs",
                   "2": "European_Union",
                   "3": "Defense",
                   "4": "Freedom",
                   "5": "Democracy",
                   "6": "Political_System",
                   "7": "Political_Authority",
                   "8": "Economy",
                   "9": "Technology_and_Infrastructure",
                   "10": "Environment",
                   "11": "Culture",
                   "12": "Equality",
                   "13": "Welfare_State",
                   "14": "Education",
                   "15": "Society_and_Values",
                   "16": "Immigration",
                   "17": "Law_and_Order",
                   "18": "Labour",
                   "19": "Agriculture",
                   "20": "NA"}

# Convert the policy-related texts into model input (at most 512 tokens, padded to the longest).
pressrelease_document_sequences = tokenizer(pressrelease_policy_df["text"].tolist(), truncation=True, padding=True, max_length=512)

class MyDatasetRaw(torch.utils.data.Dataset):
    """Unlabelled map-style dataset over tokenizer output, used for prediction.

    ``sequences`` maps feature names (e.g. ``input_ids``, ``attention_mask``)
    to per-example lists; each item is a dict of tensors for one example.
    """

    def __init__(self, sequences):
        self.sequences = sequences

    def __len__(self):
        # Size = number of encoded examples.
        return len(self.sequences['input_ids'])

    def __getitem__(self, idx):
        return dict(
            (feature, torch.tensor(values[idx]))
            for feature, values in self.sequences.items()
        )

pressrelease_whole_document_dataset = MyDatasetRaw(pressrelease_document_sequences)

# Predict the issue category of every policy-related press release.
predicted_results_pressrelease_documents = trainer.predict(pressrelease_whole_document_dataset)
pressrelease_logits = predicted_results_pressrelease_documents.predictions

# Plain arg-max labels (category code and text label) before any penalty;
# these lists are not used further in this section.
predicted_labels_pressreleases_documents = [index_to_cmp[i] for i in pressrelease_logits.argmax(-1).ravel().tolist()]
predicted_labels_text_pressreleases_documents = [text_label_dict[code] for code in predicted_labels_pressreleases_documents]

# Softmax over the class axis with torch (already loaded for the model) instead of
# importing TensorFlow just for this.
probabilities_pressreleases_documents = torch.softmax(torch.as_tensor(pressrelease_logits), dim=-1).numpy()

# One probability column per category, named "Prob_<label>_<code>", indexed like the input rows.
issue_codes = [str(i) for i in range(1, 21)]
prob_columns = [f"Prob_{text_label_dict[code]}_{code}" for code in issue_codes]
column_to_issue = {column: text_label_dict[code] for column, code in zip(prob_columns, issue_codes)}

prediction_df_pressreleases_documents = pd.DataFrame(
    probabilities_pressreleases_documents,
    columns=prob_columns,
    index=pressrelease_policy_df.index,
)

# Hand-set penalties on "Political System" and "Political Authority"
# (presumably to counter over-prediction of these two categories).
prediction_df_pressreleases_documents['Prob_Political_System_6'] -= 0.15
prediction_df_pressreleases_documents['Prob_Political_Authority_7'] -= 0.4

# Issue = label of the highest penalised probability, e.g. "Prob_Economy_8" -> "Economy".
prediction_df_pressreleases_documents["issue"] = (
    prediction_df_pressreleases_documents[prob_columns].idxmax(axis=1).map(column_to_issue)
)

prediction_df_pressreleases_documents.head()

# Attach the predictions to the policy-related releases (aligned on the index).
final_prediction_df_pressreleases_documents = pd.concat([pressrelease_policy_df, prediction_df_pressreleases_documents], axis=1)

final_prediction_df_pressreleases_documents.head()

# DataFrame.append was removed in pandas 2.0; pd.concat is the equivalent.
final_pressrelease_df = pd.concat([pressrelease_non_policy_df, final_prediction_df_pressreleases_documents])

final_pressrelease_df = final_pressrelease_df.sort_index()

final_pressrelease_df.head()

#save prediction df
final_pressrelease_df.to_csv("Results_Pressreleases.csv",index=False, encoding="utf-8") # provide correct path

"""## Tweets (Parties)"""

import re

data_path = "Corpus_TweetsParty.csv" # provide correct path

tweets_party_df = pd.read_csv(data_path, encoding="utf-8", index_col=0)

tweets_party_df.head()

# Missing texts become "" before the keyword test: `"Tipp" in text` raises
# TypeError for NaN values.
tweets_party_df["text"] = tweets_party_df["text"].fillna("")

# Tweets containing any of these keywords are labelled "not policy-related"
# and are not passed to the classifier.
tweets_non_policy_keywords = ("Tipp", "TIPP", "Live", "LIVE")
is_non_policy_tweet_party = tweets_party_df["text"].map(
    lambda text: any(keyword in text for keyword in tweets_non_policy_keywords)
)
tweets_party_df["type"] = is_non_policy_tweet_party.map({True: "not policy-related", False: "policy-related"})

# Split into independent copies so adding/overwriting columns below neither
# triggers SettingWithCopyWarning nor depends on tweets_party_df.
tweets_party_non_policy_df = tweets_party_df[tweets_party_df["type"] == "not policy-related"].copy()
tweets_party_policy_df = tweets_party_df[tweets_party_df["type"] == "policy-related"].copy()

tweets_party_non_policy_df["issue"] = "NA"

# convert unicode: the texts apparently contain "<U+XXXX>"-style escapes
# (NOTE(review): format inferred from the regex and bracket stripping -- confirm).
_UNICODE_ESCAPE = re.compile(r'U\+([0-9a-fA-F]+)')


def _decode_unicode_escape(match):
    """Return the character for a ``U+XXXX`` match; keep the match if the code point is out of range."""
    code_point = int(match.group(1), 16)
    return chr(code_point) if code_point <= 0x10FFFF else match.group(0)


def _clean_tweet(tweet):
    """Decode ``U+XXXX`` escapes, drop angle brackets, collapse all whitespace to single spaces."""
    decoded = _UNICODE_ESCAPE.sub(_decode_unicode_escape, tweet)
    # Literal "\n" sequences become line breaks, which the whitespace collapse turns into spaces.
    decoded = decoded.replace("<", "").replace(">", "").replace(r'\n', '\n')
    return " ".join(decoded.split())


tweets_party_policy_df['text'] = [_clean_tweet(tweet) for tweet in tweets_party_policy_df['text']]

print(tweets_party_policy_df['text'].iloc[0])

# Install the `emoji` package (IPython shell magic; only valid in a notebook).
!pip3 install emoji

import emoji
import re

# Replace emoji in the policy-related party tweets by their textual names (emoji.demojize).
tweets_party_policy_df['text'] =  [ emoji.demojize(tweet) for tweet in tweets_party_policy_df['text']]  

# Spot-check the first processed tweet.
print(tweets_party_policy_df['text'].iloc[0])

"""convert data to model input conform"""

# Tokenize the processed tweets: truncated to at most 512 tokens, padded to the longest.
tweets_party_document_sequences = tokenizer(list(tweets_party_policy_df["text"]), truncation=True, padding=True, max_length=512)

class MyDatasetRaw(torch.utils.data.Dataset):
    """Unlabelled map-style dataset over tokenizer output, used for prediction.

    Re-declares the class from the press-release section with the same
    behavior. ``sequences`` maps feature names to per-example lists; each
    item is a dict of tensors for one example.
    """

    def __init__(self, sequences):
        self.sequences = sequences

    def __len__(self):
        # Size = number of encoded examples.
        return len(self.sequences['input_ids'])

    def __getitem__(self, idx):
        return dict(
            (feature, torch.tensor(values[idx]))
            for feature, values in self.sequences.items()
        )

tweets_party_whole_document_dataset = MyDatasetRaw(tweets_party_document_sequences)

# Predict the issue category of every policy-related party tweet.
predicted_results_tweets_party_documents = trainer.predict(tweets_party_whole_document_dataset)
tweets_party_logits = predicted_results_tweets_party_documents.predictions

# Plain arg-max labels (category code and text label) before any penalty;
# these lists are not used further in this section.
predicted_labels_tweets_party_documents = [index_to_cmp[i] for i in tweets_party_logits.argmax(-1).ravel().tolist()]
predicted_labels_text_tweets_party_documents = [text_label_dict[code] for code in predicted_labels_tweets_party_documents]

# Softmax over the class axis with torch (already loaded for the model) instead of
# importing TensorFlow just for this.
probabilities_tweets_party_documents = torch.softmax(torch.as_tensor(tweets_party_logits), dim=-1).numpy()

# One probability column per category, named "Prob_<label>_<code>", indexed like the input rows.
issue_codes = [str(i) for i in range(1, 21)]
prob_columns = [f"Prob_{text_label_dict[code]}_{code}" for code in issue_codes]
column_to_issue = {column: text_label_dict[code] for column, code in zip(prob_columns, issue_codes)}

prediction_df_tweets_party_documents = pd.DataFrame(
    probabilities_tweets_party_documents,
    columns=prob_columns,
    index=tweets_party_policy_df.index,
)

# Hand-set penalties on "Political System" and "Political Authority"
# (presumably to counter over-prediction of these two categories).
prediction_df_tweets_party_documents['Prob_Political_System_6'] -= 0.15
prediction_df_tweets_party_documents['Prob_Political_Authority_7'] -= 0.4

# Issue = label of the highest penalised probability, e.g. "Prob_Economy_8" -> "Economy".
prediction_df_tweets_party_documents["issue"] = (
    prediction_df_tweets_party_documents[prob_columns].idxmax(axis=1).map(column_to_issue)
)

prediction_df_tweets_party_documents.head()

# Attach the predictions to the policy-related tweets (aligned on the index).
final_prediction_df_tweets_party_documents = pd.concat([tweets_party_policy_df, prediction_df_tweets_party_documents], axis=1)

final_prediction_df_tweets_party_documents.head()

# DataFrame.append was removed in pandas 2.0; pd.concat is the equivalent.
final_tweets_party_df = pd.concat([tweets_party_non_policy_df, final_prediction_df_tweets_party_documents])

# Tweets with at most five whitespace-separated words get issue "NA"
# (presumably too short to classify meaningfully).
is_short_tweet_party = final_tweets_party_df["text"].str.split().str.len() <= 5
final_tweets_party_df.loc[is_short_tweet_party, "issue"] = "NA"

final_tweets_party_df = final_tweets_party_df.sort_index()

#save prediction df
final_tweets_party_df.to_csv("Results_TweetsParty.csv",index=False, encoding="utf-8") # provide correct path

"""## Tweets (Persons)"""

import re

data_path = "Corpus_TweetsPerson.csv" # provide correct path

# NOTE(review): this file is read as latin-1 while the party-tweet file is read as
# UTF-8 -- confirm the actual encoding, otherwise umlauts/emoji may be mis-decoded.
tweets_person_df = pd.read_csv(data_path, encoding="latin-1", index_col=0)

tweets_person_df.head()

# Missing texts become "" before the keyword test: `"Tipp" in text` raises
# TypeError for NaN values.
tweets_person_df["text"] = tweets_person_df["text"].fillna("")

# Tweets containing any of these keywords are labelled "not policy-related"
# and are not passed to the classifier.
tweets_non_policy_keywords = ("Tipp", "TIPP", "Live", "LIVE")
is_non_policy_tweet_person = tweets_person_df["text"].map(
    lambda text: any(keyword in text for keyword in tweets_non_policy_keywords)
)
tweets_person_df["type"] = is_non_policy_tweet_person.map({True: "not policy-related", False: "policy-related"})

# Split into independent copies so adding/overwriting columns below neither
# triggers SettingWithCopyWarning nor depends on tweets_person_df.
tweets_person_non_policy_df = tweets_person_df[tweets_person_df["type"] == "not policy-related"].copy()
tweets_person_policy_df = tweets_person_df[tweets_person_df["type"] == "policy-related"].copy()

tweets_person_non_policy_df["issue"] = "NA"

# convert unicode: the texts apparently contain "<U+XXXX>"-style escapes
# (NOTE(review): format inferred from the regex and bracket stripping -- confirm).
_UNICODE_ESCAPE = re.compile(r'U\+([0-9a-fA-F]+)')


def _decode_unicode_escape(match):
    """Return the character for a ``U+XXXX`` match; keep the match if the code point is out of range."""
    code_point = int(match.group(1), 16)
    return chr(code_point) if code_point <= 0x10FFFF else match.group(0)


def _clean_tweet(tweet):
    """Decode ``U+XXXX`` escapes, drop angle brackets, collapse all whitespace to single spaces."""
    decoded = _UNICODE_ESCAPE.sub(_decode_unicode_escape, tweet)
    # Literal "\n" sequences become line breaks, which the whitespace collapse turns into spaces.
    decoded = decoded.replace("<", "").replace(">", "").replace(r'\n', '\n')
    return " ".join(decoded.split())


tweets_person_policy_df['text'] = [_clean_tweet(tweet) for tweet in tweets_person_policy_df['text']]

print(tweets_person_policy_df['text'].iloc[0])

# Install the `emoji` package (IPython shell magic; only valid in a notebook).
!pip3 install emoji

import emoji
import re

# Replace emoji in the policy-related person tweets by their textual names (emoji.demojize).
tweets_person_policy_df['text'] =  [ emoji.demojize(tweet) for tweet in tweets_person_policy_df['text']]  

# Spot-check the first processed tweet.
print(tweets_person_policy_df['text'].iloc[0])

# Tokenize the processed tweets: truncated to at most 512 tokens, padded to the longest.
tweets_person_document_sequences = tokenizer(list(tweets_person_policy_df["text"]), truncation=True, padding=True, max_length=512)

class MyDatasetRaw(torch.utils.data.Dataset):
    """Unlabelled map-style dataset over tokenizer output, used for prediction.

    Re-declares the class from the press-release section with the same
    behavior. ``sequences`` maps feature names to per-example lists; each
    item is a dict of tensors for one example.
    """

    def __init__(self, sequences):
        self.sequences = sequences

    def __len__(self):
        # Size = number of encoded examples.
        return len(self.sequences['input_ids'])

    def __getitem__(self, idx):
        return dict(
            (feature, torch.tensor(values[idx]))
            for feature, values in self.sequences.items()
        )

tweets_person_whole_document_dataset = MyDatasetRaw(tweets_person_document_sequences)

# Predict the issue category of every policy-related person tweet.
predicted_results_tweets_person_documents = trainer.predict(tweets_person_whole_document_dataset)
tweets_person_logits = predicted_results_tweets_person_documents.predictions

# Plain arg-max labels (category code and text label) before any penalty;
# these lists are not used further in this section.
predicted_labels_tweets_person_documents = [index_to_cmp[i] for i in tweets_person_logits.argmax(-1).ravel().tolist()]
predicted_labels_text_tweets_person_documents = [text_label_dict[code] for code in predicted_labels_tweets_person_documents]

# Softmax over the class axis with torch (already loaded for the model) instead of
# importing TensorFlow just for this.
probabilities_tweets_person_documents = torch.softmax(torch.as_tensor(tweets_person_logits), dim=-1).numpy()

# One probability column per category, named "Prob_<label>_<code>", indexed like the input rows.
issue_codes = [str(i) for i in range(1, 21)]
prob_columns = [f"Prob_{text_label_dict[code]}_{code}" for code in issue_codes]
column_to_issue = {column: text_label_dict[code] for column, code in zip(prob_columns, issue_codes)}

prediction_df_tweets_person_documents = pd.DataFrame(
    probabilities_tweets_person_documents,
    columns=prob_columns,
    index=tweets_person_policy_df.index,
)

# Hand-set penalties on "Political System" and "Political Authority"
# (presumably to counter over-prediction of these two categories).
prediction_df_tweets_person_documents['Prob_Political_System_6'] -= 0.15
prediction_df_tweets_person_documents['Prob_Political_Authority_7'] -= 0.4

# Issue = label of the highest penalised probability, e.g. "Prob_Economy_8" -> "Economy".
prediction_df_tweets_person_documents["issue"] = (
    prediction_df_tweets_person_documents[prob_columns].idxmax(axis=1).map(column_to_issue)
)

prediction_df_tweets_person_documents.head()

# Attach the predictions to the policy-related tweets (aligned on the index).
final_prediction_df_tweets_person_documents = pd.concat([tweets_person_policy_df, prediction_df_tweets_person_documents], axis=1)

final_prediction_df_tweets_person_documents.head()

# DataFrame.append was removed in pandas 2.0; pd.concat is the equivalent.
final_tweets_person_df = pd.concat([tweets_person_non_policy_df, final_prediction_df_tweets_person_documents])

# Tweets with at most five whitespace-separated words get issue "NA"
# (presumably too short to classify meaningfully).
is_short_tweet_person = final_tweets_person_df["text"].str.split().str.len() <= 5
final_tweets_person_df.loc[is_short_tweet_person, "issue"] = "NA"

final_tweets_person_df = final_tweets_person_df.sort_index()

#save prediction df -- relative path like the other result files (the previous
# leading "/" wrote to the filesystem root).
final_tweets_person_df.to_csv("Results_TweetsPerson.csv",index=False, encoding="utf-8") # provide correct path

"""## Parliamentary Speeches"""

data_path = "Corpus_ParlSpeeches.csv" # provide correct path

parlspeech_df = pd.read_csv(data_path, encoding="cp1252", index_col=0)

parlspeech_df.head()

parlspeech_df['text'] = parlspeech_df['text'].fillna("")

# Remove greeting formulas, as they may push the classifier towards the label
# "Political Authority". The patterns are applied in this order. Raw strings stop
# \s, \w, \r and \n from being read as (invalid) Python string escapes; the regexes
# themselves are unchanged.
# NOTE(review): the "...[^\r\n!]" patterns remove only the keyword plus one following
# character and are not anchored, so they also hit ordinary words inside a speech
# (e.g. "Werten", "Lieben"). Confirm whether removing the whole greeting up to "!"
# was intended.
GREETING_PATTERNS = (
    r"Herr\s\w+!+",
    r"Frau\s\w+!",
    r"Sehr geehrt[^\r\n!]",
    r"Sehr verehrt[^\r\n!]",
    r"Werte[^\r\n!]",
    r"Liebe[^\r\n!]",
    r"Geschätzte[^\r\n!]",
)

parlspeech_df['text_nogreet'] = parlspeech_df['text']
for greeting_pattern in GREETING_PATTERNS:
    parlspeech_df['text_nogreet'] = parlspeech_df['text_nogreet'].replace(greeting_pattern, "", regex=True)

parlspeech_df.head()

# Convert the greeting-free texts into model input (at most 512 tokens, padded to the longest).
parlspeech_document_sequences = tokenizer(parlspeech_df["text_nogreet"].tolist(), truncation=True, padding=True, max_length=512)

class MyDatasetRaw(torch.utils.data.Dataset):
    """Unlabelled map-style dataset over tokenizer output, used for prediction.

    Re-declares the class from the press-release section with the same
    behavior. ``sequences`` maps feature names to per-example lists; each
    item is a dict of tensors for one example.
    """

    def __init__(self, sequences):
        self.sequences = sequences

    def __len__(self):
        # Size = number of encoded examples.
        return len(self.sequences['input_ids'])

    def __getitem__(self, idx):
        return dict(
            (feature, torch.tensor(values[idx]))
            for feature, values in self.sequences.items()
        )

parlspeech_whole_document_dataset = MyDatasetRaw(parlspeech_document_sequences)

# Predict the issue category of every parliamentary speech.
predicted_results_parlspeech_documents = trainer.predict(parlspeech_whole_document_dataset)
parlspeech_logits = predicted_results_parlspeech_documents.predictions

# Plain arg-max labels (category code and text label) before any penalty;
# these lists are not used further in this section.
predicted_labels_parlspeech_documents = [index_to_cmp[i] for i in parlspeech_logits.argmax(-1).ravel().tolist()]
predicted_labels_text_parlspeech_documents = [text_label_dict[code] for code in predicted_labels_parlspeech_documents]

# Softmax over the class axis with torch (already loaded for the model) instead of
# importing TensorFlow just for this.
probabilities_parlspeech_documents = torch.softmax(torch.as_tensor(parlspeech_logits), dim=-1).numpy()

# One probability column per category, named "Prob_<label>_<code>", indexed like the input rows.
issue_codes = [str(i) for i in range(1, 21)]
prob_columns = [f"Prob_{text_label_dict[code]}_{code}" for code in issue_codes]
column_to_issue = {column: text_label_dict[code] for column, code in zip(prob_columns, issue_codes)}

prediction_df_parlspeech_documents = pd.DataFrame(
    probabilities_parlspeech_documents,
    columns=prob_columns,
    index=parlspeech_df.index,
)

# Hand-set penalties on "Political System" and "Political Authority"
# (presumably to counter over-prediction of these two categories).
prediction_df_parlspeech_documents['Prob_Political_System_6'] -= 0.15
prediction_df_parlspeech_documents['Prob_Political_Authority_7'] -= 0.4

# Issue = label of the highest penalised probability, e.g. "Prob_Economy_8" -> "Economy".
prediction_df_parlspeech_documents["issue"] = (
    prediction_df_parlspeech_documents[prob_columns].idxmax(axis=1).map(column_to_issue)
)

prediction_df_parlspeech_documents.head()

# Attach the predictions to the speeches (aligned on the index).
final_prediction_df_parlspeech_documents = pd.concat([parlspeech_df, prediction_df_parlspeech_documents], axis=1)

final_prediction_df_parlspeech_documents.head()

final_parlspeech_df = final_prediction_df_parlspeech_documents.sort_index()

# Speeches with at most 25 whitespace-separated words (counted on the original
# text, greetings included) get issue "NA".
is_short_speech = final_parlspeech_df["text"].str.split().str.len() <= 25
final_parlspeech_df.loc[is_short_speech, "issue"] = "NA"

# The greeting-free text was only needed as model input.
final_parlspeech_df = final_parlspeech_df.drop(columns="text_nogreet")

final_parlspeech_df.head()

#save prediction df
final_parlspeech_df.to_csv("Results_ParlSpeeches.csv",index=False, encoding="utf-8") # provide correct path
